# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import mindspore as ms
from mindspore import mint

from ..models.attention_processor import Attention
from ..utils import get_logger
from ..utils.mindspore_utils import unwrap_module
from ._common import (
    _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
    _ATTENTION_CLASSES,
    _FEEDFORWARD_CLASSES,
    _get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook

logger = get_logger(__name__)  # pylint: disable=invalid-name

_LAYER_SKIP_HOOK = "layer_skip_hook"
_original_scaled_dot_product_attention = Attention.scaled_dot_product_attention


# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
# either remove or make it serializable
@dataclass
class LayerSkipConfig:
    r"""
    Configuration for skipping internal transformer blocks when executing a transformer model.

    Args:
        indices (`List[int]`):
            The indices of the transformer blocks to skip, counted within the stack of blocks identified by `fqn`.
        fqn (`str`, defaults to `"auto"`):
            The fully qualified name identifying the stack of transformer blocks. Typically, this is
            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
            For automatic detection, set this to `"auto"`; automatic detection only works for DiT-style models, so
            for UNet models you must provide the correct `fqn`.
        skip_attention (`bool`, defaults to `True`):
            Whether to skip attention blocks.
        skip_ff (`bool`, defaults to `True`):
            Whether to skip feed-forward blocks.
        skip_attention_scores (`bool`, defaults to `False`):
            Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
            projections as the output of scaled dot product attention.
        dropout (`float`, defaults to `1.0`):
            The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
            meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
            skipped layers are fully retained, which is equivalent to not skipping any layers.
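
    Example (a minimal sketch; the indices and `fqn` below are illustrative and model-dependent):

    ```python
    >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks", skip_ff=False)
    >>> config_dict = config.to_dict()  # plain-dict form, e.g. for serialization
    >>> config = LayerSkipConfig.from_dict(config_dict)  # round-trips losslessly
    ```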
    """

    indices: List[int]
    fqn: str = "auto"
    skip_attention: bool = True
    skip_attention_scores: bool = False
    skip_ff: bool = True
    dropout: float = 1.0

    def __post_init__(self):
        if not (0 <= self.dropout <= 1):
            raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
        if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
            raise ValueError(
                "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
            )

    def to_dict(self):
        return asdict(self)

    @staticmethod
    def from_dict(data: dict) -> "LayerSkipConfig":
        return LayerSkipConfig(**data)


def skip_attention_scores(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
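    # Returning `value` amounts to replacing the attention-weight matrix with the identity:
    # if softmax(Q @ K^T / sqrt(d)) is replaced by I, the output is I @ V = V. This substitution
    # is only shape-consistent when the attention map is square, hence the sequence-length
    # guard below.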
    # If the Q sequence length does not match KV sequence length, methods like
    # Perturbed Attention Guidance cannot be used (because the caller expects
    # the same sequence length as Q, but if we return V here, it will not match).
    # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
    # the overall effect would be that of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
    if query.shape[2] == value.shape[2]:
        return value
    return _original_scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)


@contextmanager
def attention_score_skip_function_mode():
    """
    MindSpore context manager that emulates HuggingFace Diffusers' `AttentionScoreSkipFunctionMode` (a
    `torch.overrides.TorchFunctionMode`).

    In PyTorch, the mode intercepts `F.scaled_dot_product_attention`; here we monkey-patch
    `Attention.scaled_dot_product_attention` instead. When `query.shape[2] == value.shape[2]`, the patched function
    returns the value tensor directly, skipping the attention score computation; otherwise, the original attention
    function is executed.

    The original implementation is always restored on exit, so the patch has no side effects once the context is
    left.

    Example:

    ```python
    >>> with attention_score_skip_function_mode():
    ...     output = self.fn_ref.original_construct(*args, **kwargs)
    ```
    """
    Attention.scaled_dot_product_attention = skip_attention_scores
    try:
        yield
    finally:
        Attention.scaled_dot_product_attention = _original_scaled_dot_product_attention


class AttentionProcessorSkipHook(ModelHook):
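    r"""
    Hook that skips or attenuates a self-attention processor. With `skip_attention_scores=True`, attention score
    computation is bypassed via `attention_score_skip_function_mode`. Otherwise, the registered skip function
    replaces the output entirely (`dropout == 1.0`), or dropout is applied to the normally computed output
    (`dropout < 1.0`).
    """
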
    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
        super().__init__()
        self.skip_processor_output_fn = skip_processor_output_fn
        self.skip_attention_scores = skip_attention_scores
        self.dropout = dropout

    def new_construct(self, module: ms.nn.Cell, *args, **kwargs):
        if self.skip_attention_scores:
            if not math.isclose(self.dropout, 1.0):
                raise ValueError(
                    "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
                )
            with attention_score_skip_function_mode():
                output = self.fn_ref.original_construct(*args, **kwargs)
        else:
            if math.isclose(self.dropout, 1.0):
                output = self.skip_processor_output_fn(module, *args, **kwargs)
            else:
                output = self.fn_ref.original_construct(*args, **kwargs)
                output = mint.nn.functional.dropout(output, p=self.dropout)
        return output


class FeedForwardSkipHook(ModelHook):
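    r"""
    Hook that skips a feed-forward block: with `dropout == 1.0` the input hidden states are returned unchanged,
    otherwise dropout is applied to the block's output.
    """
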
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def new_construct(self, module: ms.nn.Cell, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            output = kwargs.get("hidden_states", None)
            if output is None:
                output = kwargs.get("x", None)
            if output is None and len(args) > 0:
                output = args[0]
        else:
            output = self.fn_ref.original_construct(*args, **kwargs)
            output = mint.nn.functional.dropout(output, p=self.dropout)
        return output


class TransformerBlockSkipHook(ModelHook):
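    r"""
    Hook that skips an entire transformer block: with `dropout == 1.0` the input hidden states (and, if the block
    also returns them, the encoder hidden states) are passed through unchanged, otherwise dropout is applied to the
    block's output.
    """
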
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def initialize_hook(self, module):
        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
        return module

    def new_construct(self, module: ms.nn.Cell, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
            if self._metadata.return_encoder_hidden_states_index is None:
                output = original_hidden_states
            else:
                original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                    "encoder_hidden_states", args, kwargs
                )
                output = (original_hidden_states, original_encoder_hidden_states)
        else:
            output = self.fn_ref.original_construct(*args, **kwargs)
            output = mint.nn.functional.dropout(output, p=self.dropout)
        return output


def apply_layer_skip(module: ms.nn.Cell, config: LayerSkipConfig) -> None:
    r"""
    Apply layer skipping to internal layers of a transformer.

    Args:
        module (`ms.nn.Cell`):
            The transformer model to which the layer skip hook should be applied.
        config (`LayerSkipConfig`):
            The configuration for the layer skip hook.

    Example:

    ```python
    >>> import mindspore as ms
    >>> from mindone.diffusers import apply_layer_skip, CogVideoXTransformer3DModel, LayerSkipConfig

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", subfolder="transformer", mindspore_dtype=ms.bfloat16)
    >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
    >>> apply_layer_skip(transformer, config)
    ```
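
    A perturbed-attention-guidance-style skip is also possible (a sketch; the index is illustrative): the layers
    still run, but self-attention outputs are replaced by the value projections:

    ```python
    >>> config = LayerSkipConfig(
    ...     indices=[10], fqn="transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False
    ... )
    >>> apply_layer_skip(transformer, config)
    ```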
    """
    _apply_layer_skip_hook(module, config)


def _apply_layer_skip_hook(module: ms.nn.Cell, config: LayerSkipConfig, name: Optional[str] = None) -> None:
    name = name or _LAYER_SKIP_HOOK

    if config.skip_attention and config.skip_attention_scores:
        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
    if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
        raise ValueError(
            "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
        )

    if config.fqn == "auto":
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
                config.fqn = identifier
                break
        else:
            raise ValueError(
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
    if transformer_blocks is None or not isinstance(transformer_blocks, ms.nn.CellList):
        raise ValueError(
            f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
            f"a `ms.nn.CellList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
        )
    if len(config.indices) == 0:
        raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

    blocks_found = False
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        if config.skip_attention and config.skip_ff:
            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
            registry = HookRegistry.check_if_exists_or_initialize(block)
            hook = TransformerBlockSkipHook(config.dropout)
            registry.register_hook(hook, name)

        elif config.skip_attention or config.skip_attention_scores:
            for submodule_name, submodule in block.cells_and_names():
                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
                    registry.register_hook(hook, name)

        elif config.skip_ff:
            for submodule_name, submodule in block.cells_and_names():
                if isinstance(submodule, _FEEDFORWARD_CLASSES):
                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = FeedForwardSkipHook(config.dropout)
                    registry.register_hook(hook, name)

    if not blocks_found:
        raise ValueError(
            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
        )
